# Copyright 2024 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Langchain-based language model using ollama to run on the local machine."""

from collections.abc import Collection, Sequence
from typing import override

from concordia.language_model import language_model
from concordia.utils import sampling
from concordia.utils.deprecated import measurements as measurements_lib
from langchain_community.llms import ollama


_MAX_MULTIPLE_CHOICE_ATTEMPTS = 20
_DEFAULT_TEMPERATURE = 0.5
_DEFAULT_TERMINATORS = ()
_DEFAULT_SYSTEM_MESSAGE = (
    "Continue the user's sentences. Never repeat their starts. For example, "
    "when you see 'Bob is', you should continue the sentence after "
    "the word 'is'. Here are some more examples: 'Question: Is Jake a "
    "turtle?\nAnswer: Jake is ' should be completed as 'not a turtle.' and "
    "'Question: What is Priya doing right now?\nAnswer: Priya is currently ' "
    "should be completed as 'working on repairing the sink.'. Notice that "
    "it is OK to be creative with how you finish the user's sentences. The "
    "most important thing is to always continue in the same style as the user."
)


class LangchainOllamaLanguageModel(language_model.LanguageModel):
  """Language Model that uses Ollama LLM models."""

  def __init__(
      self,
      model_name: str,
      *,
      system_message: str = _DEFAULT_SYSTEM_MESSAGE,
      measurements: measurements_lib.Measurements | None = None,
      channel: str = language_model.DEFAULT_STATS_CHANNEL,
  ) -> None:
    """Initializes the instance.

    Args:
        model_name: The language model to use. For more details, see
          https://github.com/ollama/ollama.
        system_message: System message to prefix to requests when prompting the
          model.
        measurements: The measurements object to log usage statistics to.
        channel: The channel to write the statistics to.
    """
    self._model_name = model_name
    self._system_message = system_message
    self._terminators = []
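    # Llama 3 models emit a special end-of-turn token ('<|eot_id|>'); treat it
    # as a stop sequence so generations do not run past the assistant's turn.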
    if 'llama3' in self._model_name:
      self._terminators.extend(['<|eot_id|>'])
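    # The langchain client sends requests to a locally running Ollama server
    # (by default at http://localhost:11434).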
    # Stop sequences are supplied per request in sample_text; langchain's
    # Ollama wrapper raises an error if `stop` is set both at construction
    # time and at call time, so the client is created without a default.
    self._client = ollama.Ollama(model=model_name)

    self._measurements = measurements
    self._channel = channel

  @override
  def sample_text(
      self,
      prompt: str,
      *,
      max_tokens: int = language_model.DEFAULT_MAX_TOKENS,
      terminators: Collection[str] = _DEFAULT_TERMINATORS,
      temperature: float = _DEFAULT_TEMPERATURE,
      top_p: float = language_model.DEFAULT_TOP_P,
      top_k: int = language_model.DEFAULT_TOP_K,
      timeout: float = -1,
      seed: int | None = None,
  ) -> str:
    del max_tokens, timeout, seed, top_k  # Unused.

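    # The system message is prepended directly to the prompt text rather than
    # sent as a separate parameter, since the model is used in plain
    # completion mode here.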
    prompt_with_system_message = f'{self._system_message}\n\n{prompt}'

    # Combine the model-specific stop sequences with any terminators the
    # caller provided; building a new list keeps self._terminators unchanged
    # across calls.
    terminators = [*self._terminators, *(terminators or ())]

    # invoke() is the current langchain entry point (bare __call__ on LLM
    # objects is deprecated); extra keyword arguments are forwarded to the
    # underlying Ollama request options.
    response = self._client.invoke(
        prompt_with_system_message,
        stop=terminators,
        temperature=temperature,
        top_p=top_p,
    )

    if self._measurements is not None:
      self._measurements.publish_datum(
          self._channel, {'raw_text_length': len(response)}
      )

    return response

  @override
  def sample_choice(
      self,
      prompt: str,
      responses: Sequence[str],
      *,
      seed: int | None = None,
  ) -> tuple[int, str, dict[str, float]]:
    sample = ''
    answer = ''
    for attempts in range(_MAX_MULTIPLE_CHOICE_ATTEMPTS):
      # Increase temperature after the first failed attempt.
      temperature = sampling.dynamically_adjust_temperature(
          attempts, _MAX_MULTIPLE_CHOICE_ATTEMPTS
      )

      # Pass the raw prompt here: sample_text prepends the system message
      # itself, so prefixing it again would duplicate it.
      sample = self.sample_text(
          prompt,
          temperature=temperature,
          seed=seed,
      )
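      # Attempt to parse the model's free-form sample into one of the
      # allowed responses.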
      answer = sampling.extract_choice_response(sample)
      try:
        idx = responses.index(answer)
      except ValueError:
        continue
      else:
        if self._measurements is not None:
          self._measurements.publish_datum(
              self._channel, {'choices_calls': attempts}
          )
        debug = {}
        return idx, responses[idx], debug

    raise language_model.InvalidResponseError(
        'Too many multiple choice attempts.\n'
        f'Last attempt: {sample}, extracted: {answer}'
    )
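
# A minimal multiple-choice sketch (hypothetical prompt and options; Concordia
# callers typically pass single-letter responses such as 'a' and 'b'):
#
#   idx, choice, _ = model.sample_choice(
#       'Choose one.\n(a) red\n(b) blue\nAnswer: (', ['a', 'b'])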
